{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b627f6ce",
   "metadata": {},
   "source": [
    "# Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "48986070",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from getpass import getpass\n",
    "from urllib.request import urlopen\n",
    "\n",
    "from elasticsearch import Elasticsearch, helpers\n",
    "from langchain.text_splitter import CharacterTextSplitter\n",
    "from langchain.vectorstores import ElasticsearchStore\n",
    "from langchain import HuggingFacePipeline\n",
    "from langchain.chains import RetrievalQA\n",
    "from langchain.prompts import ChatPromptTemplate\n",
    "from langchain.schema.output_parser import StrOutputParser\n",
    "from langchain.schema.runnable import RunnablePassthrough\n",
    "from huggingface_hub import login\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from transformers import AutoTokenizer, pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25cc9732",
   "metadata": {},
   "source": [
    "# Get environment variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6e09f549",
   "metadata": {},
   "outputs": [],
   "source": [
    "HUGGING_FACE_KEY = os.getenv(\"HUGGING_FACE_KEY\")\n",
    "ES_USER = os.getenv(\"ES_USER\")\n",
    "ES_PASSWORD = os.getenv(\"ES_PASSWORD\")\n",
    "ES_FINGERPRINT = os.getenv(\"ES_FINGERPRINT\")\n",
    "elastic_index_name='gemma-rag'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b72f136f",
   "metadata": {},
   "source": [
    "# Add documents"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0e4bbab",
   "metadata": {},
   "source": [
    "## Read sample dataset and deserialize the documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "92ac66f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully loaded 15 documents\n"
     ]
    }
   ],
   "source": [
    "# Load data into a JSON object\n",
    "with open('workplace-docs.json') as f:\n",
    "   workplace_docs = json.load(f)\n",
    " \n",
    "print(f\"Successfully loaded {len(workplace_docs)} documents\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d55426e",
   "metadata": {},
   "source": [
    "## Split Documents into Passages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "36663054",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Created a chunk of size 245, which is longer than the specified 50\n",
      "Created a chunk of size 288, which is longer than the specified 50\n",
      "Created a chunk of size 204, which is longer than the specified 50\n",
      "Created a chunk of size 281, which is longer than the specified 50\n",
      "Created a chunk of size 249, which is longer than the specified 50\n",
      "Created a chunk of size 285, which is longer than the specified 50\n",
      "Created a chunk of size 298, which is longer than the specified 50\n",
      "Created a chunk of size 270, which is longer than the specified 50\n",
      "Created a chunk of size 224, which is longer than the specified 50\n",
      "Created a chunk of size 288, which is longer than the specified 50\n",
      "Created a chunk of size 260, which is longer than the specified 50\n",
      "Created a chunk of size 199, which is longer than the specified 50\n",
      "Created a chunk of size 290, which is longer than the specified 50\n",
      "Created a chunk of size 251, which is longer than the specified 50\n",
      "Created a chunk of size 195, which is longer than the specified 50\n",
      "Created a chunk of size 242, which is longer than the specified 50\n",
      "Created a chunk of size 275, which is longer than the specified 50\n",
      "Created a chunk of size 277, which is longer than the specified 50\n",
      "Created a chunk of size 284, which is longer than the specified 50\n",
      "Created a chunk of size 285, which is longer than the specified 50\n",
      "Created a chunk of size 324, which is longer than the specified 50\n",
      "Created a chunk of size 328, which is longer than the specified 50\n",
      "Created a chunk of size 280, which is longer than the specified 50\n",
      "Created a chunk of size 123, which is longer than the specified 50\n",
      "Created a chunk of size 221, which is longer than the specified 50\n",
      "Created a chunk of size 330, which is longer than the specified 50\n",
      "Created a chunk of size 292, which is longer than the specified 50\n",
      "Created a chunk of size 240, which is longer than the specified 50\n",
      "Created a chunk of size 232, which is longer than the specified 50\n",
      "Created a chunk of size 324, which is longer than the specified 50\n",
      "Created a chunk of size 167, which is longer than the specified 50\n",
      "Created a chunk of size 181, which is longer than the specified 50\n",
      "Created a chunk of size 353, which is longer than the specified 50\n",
      "Created a chunk of size 471, which is longer than the specified 50\n",
      "Created a chunk of size 400, which is longer than the specified 50\n",
      "Created a chunk of size 607, which is longer than the specified 50\n",
      "Created a chunk of size 440, which is longer than the specified 50\n",
      "Created a chunk of size 788, which is longer than the specified 50\n",
      "Created a chunk of size 547, which is longer than the specified 50\n",
      "Created a chunk of size 67, which is longer than the specified 50\n",
      "Created a chunk of size 635, which is longer than the specified 50\n",
      "Created a chunk of size 440, which is longer than the specified 50\n",
      "Created a chunk of size 464, which is longer than the specified 50\n",
      "Created a chunk of size 372, which is longer than the specified 50\n",
      "Created a chunk of size 866, which is longer than the specified 50\n",
      "Created a chunk of size 111, which is longer than the specified 50\n",
      "Created a chunk of size 361, which is longer than the specified 50\n",
      "Created a chunk of size 110, which is longer than the specified 50\n",
      "Created a chunk of size 471, which is longer than the specified 50\n",
      "Created a chunk of size 288, which is longer than the specified 50\n",
      "Created a chunk of size 150, which is longer than the specified 50\n",
      "Created a chunk of size 62, which is longer than the specified 50\n",
      "Created a chunk of size 619, which is longer than the specified 50\n",
      "Created a chunk of size 257, which is longer than the specified 50\n",
      "Created a chunk of size 310, which is longer than the specified 50\n",
      "Created a chunk of size 310, which is longer than the specified 50\n",
      "Created a chunk of size 282, which is longer than the specified 50\n",
      "Created a chunk of size 230, which is longer than the specified 50\n",
      "Created a chunk of size 213, which is longer than the specified 50\n",
      "Created a chunk of size 395, which is longer than the specified 50\n",
      "Created a chunk of size 270, which is longer than the specified 50\n",
      "Created a chunk of size 224, which is longer than the specified 50\n",
      "Created a chunk of size 292, which is longer than the specified 50\n",
      "Created a chunk of size 231, which is longer than the specified 50\n",
      "Created a chunk of size 153, which is longer than the specified 50\n",
      "Created a chunk of size 277, which is longer than the specified 50\n",
      "Created a chunk of size 171, which is longer than the specified 50\n",
      "Created a chunk of size 106, which is longer than the specified 50\n",
      "Created a chunk of size 253, which is longer than the specified 50\n",
      "Created a chunk of size 140, which is longer than the specified 50\n",
      "Created a chunk of size 323, which is longer than the specified 50\n",
      "Created a chunk of size 95, which is longer than the specified 50\n",
      "Created a chunk of size 201, which is longer than the specified 50\n",
      "Created a chunk of size 174, which is longer than the specified 50\n",
      "Created a chunk of size 232, which is longer than the specified 50\n",
      "Created a chunk of size 238, which is longer than the specified 50\n",
      "Created a chunk of size 193, which is longer than the specified 50\n",
      "Created a chunk of size 305, which is longer than the specified 50\n",
      "Created a chunk of size 1120, which is longer than the specified 50\n",
      "Created a chunk of size 397, which is longer than the specified 50\n",
      "Created a chunk of size 320, which is longer than the specified 50\n",
      "Created a chunk of size 567, which is longer than the specified 50\n",
      "Created a chunk of size 300, which is longer than the specified 50\n",
      "Created a chunk of size 217, which is longer than the specified 50\n",
      "Created a chunk of size 271, which is longer than the specified 50\n",
      "Created a chunk of size 264, which is longer than the specified 50\n",
      "Created a chunk of size 340, which is longer than the specified 50\n",
      "Created a chunk of size 383, which is longer than the specified 50\n",
      "Created a chunk of size 243, which is longer than the specified 50\n",
      "Created a chunk of size 344, which is longer than the specified 50\n",
      "Created a chunk of size 225, which is longer than the specified 50\n",
      "Created a chunk of size 208, which is longer than the specified 50\n",
      "Created a chunk of size 420, which is longer than the specified 50\n",
      "Created a chunk of size 89, which is longer than the specified 50\n",
      "Created a chunk of size 94, which is longer than the specified 50\n",
      "Created a chunk of size 299, which is longer than the specified 50\n",
      "Created a chunk of size 310, which is longer than the specified 50\n"
     ]
    }
   ],
   "source": [
    "metadata = []\n",
    "content = []\n",
    "\n",
    "for doc in workplace_docs:\n",
    "    content.append(doc[\"content\"])\n",
    "    metadata.append(\n",
    "        {\n",
    "            \"name\": doc[\"name\"],\n",
    "            \"summary\": doc[\"summary\"],\n",
    "            \"rolePermissions\": doc[\"rolePermissions\"],\n",
    "        }\n",
    "    )\n",
    "\n",
    "text_splitter = CharacterTextSplitter(chunk_size=50, chunk_overlap=0)\n",
    "docs = text_splitter.create_documents(content, metadatas=metadata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f1af829b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'name': 'liuxgm.local', 'cluster_name': 'elasticsearch', 'cluster_uuid': '1FpGdI1WRcaqlKk01d0dYw', 'version': {'number': '8.12.0', 'build_flavor': 'default', 'build_type': 'tar', 'build_hash': '1665f706fd9354802c02146c1e6b5c0fbcddfbc9', 'build_date': '2024-01-11T10:05:27.953830042Z', 'build_snapshot': False, 'lucene_version': '9.9.1', 'minimum_wire_compatibility_version': '7.17.0', 'minimum_index_compatibility_version': '7.0.0'}, 'tagline': 'You Know, for Search'}\n"
     ]
    }
   ],
   "source": [
    "url = f\"https://{ES_USER}:{ES_PASSWORD}@localhost:9200\"\n",
    "client = Elasticsearch(url, ca_certs = \"./http_ca.crt\", verify_certs = True)\n",
    "print(client.info())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1ebec8ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "es = ElasticsearchStore.from_documents( \n",
    "                            docs,\n",
    "                            strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(model_id=\".elser_model_2\"),\n",
    "                            es_url = url, \n",
    "                            es_connection = client,\n",
    "                            index_name = elastic_index_name, \n",
    "                            es_user = ES_USER,\n",
    "                            es_password = ES_PASSWORD)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a55ec462-1632-4d98-bda8-e1435e3597b6",
   "metadata": {},
   "source": [
    "# Hugging face login"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fae85c8d-2642-4010-bc76-e94d21ffa98d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8f46242ce5d54726a22ed269a2eb16b2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f01121bd-c124-446a-a953-f7c4437a43ee",
   "metadata": {},
   "source": [
    "## Initialize the tokenizer with the model (google/gemma-2b-it)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fa8153b6-31c7-4ae3-9aa7-feb65a8d8a4b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ce1ec1a5fe4044ca87b3e784b663cdd2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9a886dc058724ba6adf837c5b84f349e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "94fc528af57b48138a6bce19da0a00cb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5db2fb01a77345da8eed46952af036b6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aa0edaf235a641c898845b2d3e48f2e8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7a8f15e9597844a1a0c105237737fd22",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6f59648397ad4c689773a98fe0fed1be",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "11419fba00714b87bd5f4f4d81e089e6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/2.16k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "87ba1112e05f44689918d5e295b90b17",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "10f5fce50c8b4efaac332efc1573dcc1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f22c9cfde0ee4de98d26f8b688b970ff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/888 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\"google/gemma-2b-it\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2b-it\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b7c4e65-cdac-46bd-a8a2-54c702eaba38",
   "metadata": {},
   "source": [
    "# Create a text-generation pipeline and initialize with LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "9705e8c8-9a01-4769-90f5-470f2aa97f71",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = pipeline(\n",
    "    \"text-generation\",\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    max_new_tokens=1024,\n",
    ")\n",
    "\n",
    "llm = HuggingFacePipeline(\n",
    "    pipeline=pipe,\n",
    "    model_kwargs={\"temperature\": 0.7},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b1fbab0-1672-4826-bf42-92cc9e4bbe93",
   "metadata": {},
   "source": [
    "# Format Docs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "a3cde1f6-5fa0-4368-8381-dd78bd141426",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_docs(docs):\n",
    "    return \"\\n\\n\".join(doc.page_content for doc in docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "744a30ee-1151-4da6-9d02-fc9e52f3b236",
   "metadata": {},
   "source": [
    "# Create a chain using Prompt template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "fdbaab6d-621d-4823-b5ac-6521023ea1f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = es.as_retriever(search_kwargs={\"k\": 10})\n",
    "\n",
    "template = \"\"\"Answer the question based only on the following context:\\n\n",
    "\n",
    "{context}\n",
    "\n",
    "Question: {question}\n",
    "\"\"\"\n",
    "prompt = ChatPromptTemplate.from_template(template)\n",
    "\n",
    "\n",
    "chain = (\n",
    "    {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
    "    | prompt\n",
    "    | llm\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3dea0fce-1f9b-4323-93c4-08c034a53595",
   "metadata": {},
   "source": [
    "# Ask question"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "25539e18-bbc5-4cb4-a110-b1d73a853c67",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Answer: The pet policy in the office allows employees to bring pets to the office, subject to approval by the HR department. Pets covered under this policy include dogs, cats, and other small, non-exotic animals, subject to approval by the HR department.'"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke(\"What is the pet policy in the office?\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test1",
   "language": "python",
   "name": "test1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
